
from __future__ import print_function
from argparse import ArgumentParser
import cv2
import csv
import os.path
import numpy as np
import torch
from torch.optim import Adam, lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from criterion import CrossEntropyLoss2d
from datasets import CD2014
from datasets import levir
import sys
#sys.path.append("./correlation_package/build/lib.linux-x86_64-3.5")
import cscdnet
import utils.transforms as trans


# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the caller has
# already exported (e.g. `CUDA_VISIBLE_DEVICES=1 python <script> ...`).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '0')

def colormap():
    """Return the RGB palette for the two change-mask labels.

    The result is a ``(2, 3)`` ``uint8`` array: row 0 (label 0) is black
    and row 1 (label 1) is white.
    """
    black = [0, 0, 0]
    white = [255, 255, 255]
    return np.array([black, white], dtype=np.uint8)


class Colorization:
    """Turn a label image into a 3-channel ``uint8`` colour image.

    The palette comes from ``colormap()``; only its first ``n`` rows are kept.
    """

    def __init__(self, n=2):
        # Copy the selected palette rows so the tensor owns its own memory.
        self.cmap = torch.from_numpy(colormap()[:n].copy())

    def __call__(self, gray_image):
        """Colour ``gray_image`` using the palette.

        Only ``gray_image[0]`` is read as the label map. The output has shape
        ``(3, size[1], size[2])``, where ``size`` is ``gray_image.size()``.
        Pixels whose label has no palette entry stay 0 in every channel.
        """
        height, width = gray_image.size()[1:3]
        color_image = torch.zeros(3, height, width, dtype=torch.uint8, device='cpu')

        labels = gray_image[0]
        for label, rgb in enumerate(self.cmap):
            hit = labels == label
            for channel in range(3):
                color_image[channel][hit] = rgb[channel]

        return color_image


class Training:
    """Train ``cscdnet.Model`` for change detection and checkpoint it.

    ``arguments`` is the parsed command-line namespace. All outputs (loss CSV,
    tensorboard logs, checkpoints) go under
    ``<checkpointdir>/cdnet/checkpointdir/set<cvset>``.
    """

    def __init__(self, arguments):
        self.args = arguments
        # Number of optimizer updates performed so far, across all epochs.
        self.icount = 0
        self.dn_save = os.path.join(self.args.checkpointdir, 'cdnet', 'checkpointdir',
                                    'set{}'.format(self.args.cvset))

    def _build_train_loader(self):
        """Return a shuffled ``DataLoader`` over the 'train' split of ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` is neither ``"cdnet"`` nor ``"levir"``.
        """
        dataset_type = self.args.dataset
        # Per dataset: dataset class, list file under <datadir>/dataset, Scale size.
        if dataset_type == "cdnet":
            dataset_cls, list_name, scale_size = CD2014.Dataset, "supply.txt", (512, 768)
        elif dataset_type == "levir":
            dataset_cls, list_name, scale_size = levir.Dataset, "train.txt", (512, 512)
        else:
            # Fail fast: without a known dataset there is no loader to train on.
            raise ValueError(
                "unknown dataset {!r}; expected 'cdnet' or 'levir'".format(dataset_type))

        train_transform_det = trans.Compose([
            trans.Scale(scale_size),
        ])
        # Both the data root and the label root are <datadir>/dataset.
        data_dir = os.path.join(self.args.datadir, "dataset")
        return DataLoader(
            dataset_cls(data_dir, data_dir, os.path.join(data_dir, list_name),
                        'train', transform=True, transform_med=train_transform_det),
            num_workers=self.args.num_workers, batch_size=self.args.batch_size, shuffle=True)

    def _lr_factor(self, step):
        """Return the LR multiplier after ``step`` updates.

        The multiplier decays linearly from 1 to 0 over ``args.max_iteration``
        updates and is clamped at 0 so the learning rate never goes negative.
        """
        total = float(self.args.max_iteration)
        if total <= 0:
            return 0.0
        return max(0.0, (total - step) / total)

    def train(self):
        """Train ``self.model`` for exactly ``args.max_iteration`` optimizer updates.

        Passes over the training loader repeat until the iteration budget is
        reached, stopping mid-epoch if necessary. Each iteration's loss is
        appended to ``<dn_save>/loss.csv``. A checkpoint is written every
        ``args.icount_save`` iterations; this is disabled when
        ``icount_save <= 0``.

        Raises:
            ValueError: if ``args.dataset`` is not supported.
            RuntimeError: if the training loader yields no batches.
        """
        self.color_transform = Colorization(2)

        dataset_train = self._build_train_loader()

        self.test_path = os.path.join(self.dn_save, 'test')
        if not os.path.exists(self.test_path):
            # Also creates dn_save, which holds loss.csv, logs and checkpoints.
            os.makedirs(self.test_path)

        # Set loss function, optimizer and learning rate
        weight = torch.ones(2)
        criterion = CrossEntropyLoss2d(weight.cuda())
        optimizer = Adam(self.model.parameters(), lr=0.0001, betas=(0.5, 0.999))
        # Stepped once per iteration, after optimizer.step() (the order PyTorch
        # expects), so the LR decays linearly to 0 over max_iteration updates.
        model_lr_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=self._lr_factor)

        fn_loss = os.path.join(self.dn_save, 'loss.csv')
        with open(fn_loss, 'w') as f_loss:
            loss_writer = csv.writer(f_loss)
            self.writers = SummaryWriter(os.path.join(self.dn_save, 'log'))
            try:
                # Training loop
                while self.icount < self.args.max_iteration:
                    n_batches = 0
                    for inputs_train, mask_train in dataset_train:
                        n_batches += 1
                        inputs_train = inputs_train.cuda()
                        mask_train = mask_train.cuda()
                        outputs_train, _ = self.model(inputs_train)

                        optimizer.zero_grad()
                        # NOTE(review): assumes channel 0 of the mask holds the
                        # class labels expected by the criterion - confirm.
                        self.loss = criterion(outputs_train, mask_train[:, 0])
                        self.loss.backward()
                        optimizer.step()
                        model_lr_scheduler.step()

                        self.icount += 1
                        loss_value = self.loss.item()
                        print("self.icount: ", self.icount)
                        loss_writer.writerow([self.icount, loss_value])

                        if self.args.icount_save > 0 and self.icount % self.args.icount_save == 0:
                            self.checkpoint()
                        # Stop exactly at the budget instead of finishing the epoch.
                        if self.icount >= self.args.max_iteration:
                            break
                    if n_batches == 0:
                        # Otherwise the outer while-loop would spin forever.
                        raise RuntimeError('training data loader yielded no batches')
            finally:
                self.writers.close()

    # Output results for tensorboard
    def log_tbx(self, image):
        """Log the latest loss and ``image`` to tensorboard at step ``self.icount``.

        Uses ``self.writers`` and ``self.loss``, which are set up by ``train()``.
        """
        writer = self.writers
        writer.add_scalar('data/loss', self.loss.item(), self.icount)
        writer.add_image('change detection', image, self.icount)

    def checkpoint(self):
        """Save ``self.model.state_dict()`` into ``dn_save``, tagged with ``self.icount``.

        The filename prefix is ``cscdnet`` when ``args.use_corr`` is set
        (correlation layer enabled) and ``cdnet`` otherwise.
        """
        prefix = 'cscdnet' if self.args.use_corr else 'cdnet'
        filename = '{0}-{1:08d}.pth'.format(prefix, self.icount)
        if not os.path.isdir(self.dn_save):
            os.makedirs(self.dn_save)
        torch.save(self.model.state_dict(), os.path.join(self.dn_save, filename))
        print('save: {0} (iteration: {1})'.format(filename, self.icount))

    def run(self):
        """Build the model on the GPU and train it."""
        # Honour --use-corr. checkpoint() already names files after this flag,
        # so the model architecture must match it.
        self.model = cscdnet.Model(inc=6, outc=2, corr=self.args.use_corr, pretrained=True)
        self.model = self.model.cuda()
        self.train()


def build_arg_parser():
    """Return the command-line parser for this training script.

    ``--dataset`` is restricted to the datasets ``Training`` can load, so a
    typo fails at parse time instead of after the model has been built.
    """
    parser = ArgumentParser(description='Start training ...')
    parser.add_argument('--checkpointdir', required=True)
    parser.add_argument('--datadir', required=True)
    parser.add_argument('--dataset', required=True, choices=['cdnet', 'levir'])
    parser.add_argument('--use-corr', action='store_true', help='using correlation layer')
    parser.add_argument('--max-iteration', type=int, default=50000,
                        help='number of optimizer updates to run')
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--cvset', type=int, default=0,
                        help='index used in the output directory name (set<N>)')
    parser.add_argument('--icount-save', type=int, default=10,
                        help='save a checkpoint every N iterations (<= 0 disables)')
    return parser


if __name__ == '__main__':
    training = Training(build_arg_parser().parse_args())
    training.run()
